k-Means Clustering

  • In the previous few sections, we explored one category of unsupervised machine learning models: dimensionality reduction.
  • Here we will move on to another class of unsupervised machine learning models: clustering algorithms.
  • Clustering algorithms seek to learn, from the properties of the data, an optimal division or discrete labeling of groups of points.
  • Many clustering algorithms are available in Scikit-Learn, but the simplest to understand is k-means clustering.

We begin with the standard imports:


In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()  # for plot styling
import numpy as np

Introducing k-Means

  • The k-means algorithm searches for a pre-determined number of clusters within an unlabeled multidimensional dataset.
  • It accomplishes this using a simple conception of what the optimal clustering looks like:
    • The "cluster center" is the arithmetic mean of all the points belonging to the cluster.
    • Each point is closer to its own cluster center than to other cluster centers.

First, let's generate a two-dimensional dataset containing four distinct blobs:


In [ ]:
from sklearn.datasets import make_blobs
X, y_true = make_blobs(n_samples=300, centers=4,
                       cluster_std=0.60, random_state=0)
plt.scatter(X[:, 0], X[:, 1], s=50);
  • Here, it's relatively easy to pick out the four clusters.
  • The k-means algorithm does this automatically:

In [ ]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=4)
kmeans.fit(X)
y_kmeans = kmeans.predict(X)
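
As a quick sanity check (an aside, using only the objects defined above), we can verify on the fitted model the two properties stated earlier: each cluster center is the arithmetic mean of its assigned points, and each point is nearest to its own center:


In [ ]:
from sklearn.metrics import pairwise_distances_argmin

centers = kmeans.cluster_centers_
# 1. each center is (approximately) the mean of the points assigned to it
print(np.allclose(centers, [X[y_kmeans == i].mean(0) for i in range(4)]))
# 2. each point is closer to its own center than to any other center
print(np.all(pairwise_distances_argmin(X, centers) == y_kmeans))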

Let's visualize the results by plotting the data colored by these labels. We will also plot the cluster centers as determined by the k-means estimator:


In [ ]:
plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis')

centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=0.5);
  • The k-means algorithm (at least in this simple case) assigns the points to clusters very similarly to how we might assign them by eye.
  • You might wonder how this algorithm finds these clusters so quickly!
  • After all, the number of possible cluster assignments is exponential in the number of data points, so an exhaustive search would be very, very costly (see the quick calculation below).
  • Fortunately, the standard approach to k-means clustering avoids this search with an iterative procedure known as expectation–maximization.
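
To get a feel for that cost, here is a rough back-of-envelope calculation (illustrative only; it ignores the symmetry between cluster labels):


In [ ]:
# With n points and k clusters there are on the order of k**n candidate
# assignments, far too many to enumerate exhaustively.
n_points, n_clusters = 300, 4
print(f"roughly 10^{n_points * np.log10(n_clusters):.0f} possible assignments")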

Expectation–Maximization

  1. Guess some cluster centers
  2. Repeat until converged
    • E-step: assign each point to the nearest cluster center
    • M-step: set each cluster center to the mean of the points assigned to it

Interactive Example


In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()  # for plot styling
import numpy as np

from ipywidgets import interact
from sklearn.metrics import pairwise_distances_argmin
from sklearn.datasets import make_blobs

def plot_kmeans_interactive(min_clusters=1, max_clusters=6):
    X, y = make_blobs(n_samples=300, centers=4,
                      random_state=0, cluster_std=0.60)
        
    def plot_points(X, labels, n_clusters):
        plt.scatter(X[:, 0], X[:, 1], c=labels, s=50, cmap='viridis',
                    vmin=0, vmax=n_clusters - 1);
            
    def plot_centers(centers):
        plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                    c=np.arange(centers.shape[0]),
                    s=200, cmap='viridis')
        plt.scatter(centers[:, 0], centers[:, 1], marker='o',
                    c='black', s=50)
        
    def _kmeans_step(frame=0, n_clusters=4):
        rng = np.random.RandomState(2)
        labels = np.zeros(X.shape[0])
        centers = rng.randn(n_clusters, 2)

        nsteps = frame // 3

        for i in range(nsteps + 1):
            old_centers = centers
            if i < nsteps or frame % 3 > 0:
                labels = pairwise_distances_argmin(X, centers)

            if i < nsteps or frame % 3 > 1:
                centers = np.array([X[labels == j].mean(0)
                                    for j in range(n_clusters)])
                nans = np.isnan(centers)
                centers[nans] = old_centers[nans]

        # plot the data and cluster centers
        plot_points(X, labels, n_clusters)
        plot_centers(old_centers)

        # plot new centers if third frame
        if frame % 3 == 2:
            for i in range(n_clusters):
                plt.annotate('', centers[i], old_centers[i], 
                             arrowprops=dict(arrowstyle='->', linewidth=1))
            plot_centers(centers)

        plt.xlim(-4, 4)
        plt.ylim(-2, 10)

        if frame % 3 == 1:
            plt.text(3.8, 9.5, "1. Reassign points to nearest centroid",
                     ha='right', va='top', size=14)
        elif frame % 3 == 2:
            plt.text(3.8, 9.5, "2. Update centroids to cluster means",
                     ha='right', va='top', size=14)
            
    return interact(_kmeans_step, frame=(0, 50),
                    n_clusters=(min_clusters, max_clusters))

plot_kmeans_interactive();

The algorithm is simple enough that we can implement it in just a few lines of code:


In [ ]:
from sklearn.metrics import pairwise_distances_argmin

def find_clusters(X, n_clusters, rseed=2):
    # 1. Randomly choose clusters
    rng = np.random.RandomState(rseed)
    i = rng.permutation(X.shape[0])[:n_clusters]
    centers = X[i]
    
    while True:
        # 2a. Assign labels based on closest center
        labels = pairwise_distances_argmin(X, centers)
        
        # 2b. Find new centers from means of points
        new_centers = np.array([X[labels == i].mean(0)
                                for i in range(n_clusters)])
        
        # 2c. Check for convergence
        if np.all(centers == new_centers):
            break
        centers = new_centers
    
    return centers, labels

centers, labels = find_clusters(X, 4)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');

Issues with the EM algorithm

The globally optimal result may not be achieved
  • Although each E–M step is guaranteed to improve the result, there is no assurance that the procedure will reach the globally optimal solution.
  • For example, if we use a different random seed in our simple procedure, the particular starting guesses lead to poor results:

In [ ]:
centers, labels = find_clusters(X, 4, rseed=0)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
  • For this reason, it is usual to run the algorithm several times with different starting guesses and keep the best result.
  • Scikit-Learn does this by default: the n_init parameter controls the number of runs (10 by default).
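
As a minimal sketch of this (using the n_init parameter and the inertia_ attribute, which holds the final within-cluster sum of squares), we can compare a single initialization against multiple restarts:


In [ ]:
# Lower inertia_ (within-cluster sum of squares) indicates a better clustering.
single = KMeans(n_clusters=4, n_init=1, random_state=0).fit(X)
multi = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)
print(single.inertia_, multi.inertia_)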

The number of clusters must be selected beforehand

  • A challenge with k-means is that you must tell it how many clusters you expect.
  • It cannot learn the number of clusters from the data.
  • For example, if we ask the algorithm to identify six clusters, it will happily proceed and find the best six clusters:

In [ ]:
labels = KMeans(6, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
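
As an aside (not covered above), plotting inertia_ for a range of cluster counts shows why the data alone cannot simply select k: the objective keeps decreasing as k grows, so heuristics such as looking for an "elbow" in this curve are commonly used:


In [ ]:
# Inertia always decreases as k grows, so it cannot simply be minimized to
# choose k; the "elbow" heuristic looks for the point of diminishing returns.
ks = range(1, 10)
inertias = [KMeans(n_clusters=k, random_state=0).fit(X).inertia_ for k in ks]
plt.plot(ks, inertias, '-o')
plt.xlabel('number of clusters k')
plt.ylabel('inertia');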

k-means is limited to linear cluster boundaries

  • The model assumptions of k-means mean that the algorithm will often be ineffective when clusters have complicated geometries.
  • In particular, the boundaries between k-means clusters will always be linear, so it will fail for more complicated boundaries.
  • Consider the following data, along with the cluster labels found by the typical k-means approach:

In [ ]:
from sklearn.datasets import make_moons
X, y = make_moons(200, noise=.05, random_state=0)

In [ ]:
labels = KMeans(2, random_state=0).fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');
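
As an aside (a different algorithm, not covered further in this section), a graph-based method such as SpectralClustering can recover the two moons, which illustrates what the linear boundaries of k-means cannot do; a minimal sketch:


In [ ]:
from sklearn.cluster import SpectralClustering

# Cluster on a nearest-neighbor graph embedding instead of raw coordinates,
# then assign labels with k-means in the embedded space.
model = SpectralClustering(n_clusters=2, affinity='nearest_neighbors',
                           assign_labels='kmeans')
labels = model.fit_predict(X)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis');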

k-means can be slow for large numbers of samples

  • Because each iteration of k-means must access every point in the dataset, the algorithm can be relatively slow as the number of samples grows.
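
One common remedy (an aside, not used elsewhere in this section) is the batch-based sklearn.cluster.MiniBatchKMeans estimator, which updates the centers from small random subsets of the data; a minimal sketch on a larger synthetic dataset:


In [ ]:
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

# Illustrative only: fit on a larger blob dataset using mini-batch updates,
# which is much faster than full k-means when the sample count is large.
X_big, _ = make_blobs(n_samples=100000, centers=4,
                      cluster_std=0.60, random_state=0)
mbk = MiniBatchKMeans(n_clusters=4, batch_size=1000, random_state=0)
labels_big = mbk.fit_predict(X_big)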

Example 1: k-means on digits

  • Let's take a look at applying k-means on the same simple digits data that we already saw.
  • Here we will use k-means to try to identify similar digits without using the original label information.

Recall that the digits consist of 1,797 samples with 64 features, where each of the 64 features is the brightness of one pixel in an 8×8 image:


In [ ]:
from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shape

In [ ]:
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits.data)
kmeans.cluster_centers_.shape

What do these clusters look like?


In [ ]:
fig, ax = plt.subplots(2, 5, figsize=(8, 3))
centers = kmeans.cluster_centers_.reshape(10, 8, 8)
for axi, center in zip(ax.flat, centers):
    axi.set(xticks=[], yticks=[])
    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)

We can have a look at the true labels of the datapoints assigned to a cluster:


In [ ]:
cluster_no = 6
digits.target[(clusters == cluster_no)]
  • Obviously, there is a mismatch: the cluster numbers assigned by k-means are arbitrary and do not necessarily correspond to the true digit identities.
  • We can fix that by matching each learned cluster with the most common true label of the points it contains:

In [ ]:
from scipy.stats import mode

labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]
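
For illustration (an aside), we can print which true digit each learned cluster was matched to:


In [ ]:
# Show the cluster -> digit mapping produced by the matching step above.
for i in range(10):
    print("cluster", i, "-> digit", labels[clusters == i][0])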

Now we can check how accurate our unsupervised clustering was in finding similar digits within the data:


In [ ]:
from sklearn.metrics import accuracy_score
accuracy_score(digits.target, labels)

With just a simple k-means algorithm, we discovered the correct grouping for 80% of the input digits! Let's check the confusion matrix for this:


In [ ]:
from sklearn.metrics import confusion_matrix
mat = confusion_matrix(digits.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=digits.target_names,
            yticklabels=digits.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');

Just for fun

  • We can use the t-distributed stochastic neighbor embedding (t-SNE) algorithm to pre-process the data before performing k-means.
  • t-SNE is a nonlinear embedding algorithm that is particularly adept at preserving points within clusters.

Let's see how it does:


In [ ]:
from sklearn.manifold import TSNE

# Project the data: this step will take several seconds
tsne = TSNE(n_components=2, init='random', random_state=0)
digits_proj = tsne.fit_transform(digits.data)

# Compute the clusters
kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits_proj)

# Permute the labels
labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]

# Compute the accuracy
accuracy_score(digits.target, labels)

In [ ]:
mat = confusion_matrix(digits.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels=digits.target_names,
            yticklabels=digits.target_names)
plt.xlabel('true label')
plt.ylabel('predicted label');